{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5wFF5JFyD2Ki"
      },
      "source": [
        "#### Copyright 2019 The TensorFlow Hub Authors.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Uf6NouXxDqGk"
      },
      "outputs": [],
      "source": [
        "# Copyright 2019 The TensorFlow Hub Authors. All Rights Reserved.\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "# =============================================================================="
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ORy-KvWXGXBo"
      },
      "source": [
        "# Exploring the TF-Hub CORD-19 Swivel Embeddings\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MfBg1C5NB3X0"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/hub/tutorials/cord_19_embeddings\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/hub/blob/master/examples/colab/cord_19_embeddings.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/hub/blob/master/examples/colab/cord_19_embeddings.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/hub/examples/colab/cord_19_embeddings.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://tfhub.dev/tensorflow/cord-19/swivel-128d/1\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/hub_logo_32px.png\" /\u003eSee TF Hub model\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9VusdTAH0isl"
      },
      "source": [
        "The CORD-19 Swivel text embedding module from TF-Hub (https://tfhub.dev/tensorflow/cord-19/swivel-128d/1)\n",
        " was built to support researchers analyzing natural languages text related to COVID-19.\n",
        "These embeddings were trained on the titles, authors, abstracts, body texts, and\n",
        "reference titles of articles in the [CORD-19 dataset](https://pages.semanticscholar.org/coronavirus-research).\n",
        "\n",
        "In this colab we will:\n",
        "- Analyze semantically similar words in the embedding space\n",
        "- Train a classifier on the SciCite dataset using the CORD-19 embeddings\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "L69VQv2Z0isl"
      },
      "source": [
        "## Setup\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ym2nXOPuPV__"
      },
      "outputs": [],
      "source": [
        "import functools\n",
        "import itertools\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import seaborn as sns\n",
        "import pandas as pd\n",
        "\n",
        "import tensorflow.compat.v1 as tf\n",
        "tf.disable_eager_execution()\n",
        "tf.logging.set_verbosity('ERROR')\n",
        "\n",
        "import tensorflow_datasets as tfds\n",
        "import tensorflow_hub as hub\n",
        "\n",
        "try:\n",
        "  from google.colab import data_table\n",
        "  def display_df(df):\n",
        "    return data_table.DataTable(df, include_index=False)\n",
        "except ModuleNotFoundError:\n",
        "  # If google-colab is not available, just display the raw DataFrame\n",
        "  def display_df(df):\n",
        "    return df"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_VgRRf2I7tER"
      },
      "source": [
        "# Analyze the embeddings\n",
        "\n",
        "Let's start off by analyzing the embedding by calculating and plotting a correlation matrix between different terms. If the embedding learned to successfully capture the meaning of different words, the embedding vectors of semantically similar words should be close together. Let's take a look at some COVID-19 related terms."
      ]
    },
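    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick illustration of the similarity measure used in the next cell, here is a minimal sketch on made-up 3-dimensional vectors. The names `toy_labels` and `toy_features` and their values are assumptions for illustration only; they are not real CORD-19 embeddings. The sketch also prints the cosine similarity for comparison, which differs from the scaled inner product whenever the vectors have different lengths."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical toy vectors (NOT produced by the CORD-19 module).\n",
        "toy_labels = [\"virus_a\", \"virus_b\", \"region_a\"]\n",
        "toy_features = np.array([\n",
        "    [1.0, 0.9, 0.0],\n",
        "    [0.9, 1.0, 0.1],\n",
        "    [0.0, 0.1, 1.0],\n",
        "])\n",
        "\n",
        "# Pairwise inner products, scaled so that the largest entry equals 1.\n",
        "toy_similarity = np.inner(toy_features, toy_features)\n",
        "toy_similarity /= np.max(toy_similarity)\n",
        "\n",
        "# Cosine similarity for comparison: scale each vector to unit length first.\n",
        "toy_unit = toy_features / np.linalg.norm(toy_features, axis=1, keepdims=True)\n",
        "toy_cosine = np.inner(toy_unit, toy_unit)\n",
        "\n",
        "print(toy_labels)\n",
        "print(np.round(toy_similarity, 2))\n",
        "print(np.round(toy_cosine, 2))"
      ]
    },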
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HNN_9bBKSLHU"
      },
      "outputs": [],
      "source": [
        "# Use the inner product between two embedding vectors as the similarity measure\n",
        "def plot_correlation(labels, features):\n",
        "  corr = np.inner(features, features)\n",
        "  corr /= np.max(corr)\n",
        "  sns.heatmap(corr, xticklabels=labels, yticklabels=labels)\n",
        "\n",
        "\n",
        "with tf.Graph().as_default():\n",
        "  # Load the module\n",
        "  query_input = tf.placeholder(tf.string)\n",
        "  module = hub.Module('https://tfhub.dev/tensorflow/cord-19/swivel-128d/1')\n",
        "  embeddings = module(query_input)\n",
        "\n",
        "  with tf.train.MonitoredTrainingSession() as sess:\n",
        "\n",
        "    # Generate embeddings for some terms\n",
        "    queries = [\n",
        "        # Related viruses\n",
        "        \"coronavirus\", \"SARS\", \"MERS\",\n",
        "        # Regions\n",
        "        \"Italy\", \"Spain\", \"Europe\",\n",
        "        # Symptoms\n",
        "        \"cough\", \"fever\", \"throat\"\n",
        "    ]\n",
        "\n",
        "    features = sess.run(embeddings, feed_dict={query_input: queries})\n",
        "    plot_correlation(queries, features)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Bg-PGqtm8B7K"
      },
      "source": [
        "We can see that the embedding successfully captured the meaning of the different terms. Each word is similar to the other words of its cluster (i.e. \"coronavirus\" highly correlates with \"SARS\" and \"MERS\"), while they are different from terms of other clusters (i.e. the similarity between \"SARS\" and \"Spain\" is close to 0).\n",
        "\n",
        "Now let's see how we can use these embeddings to solve a specific task."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "idJ1jFmH7xMa"
      },
      "source": [
        "## SciCite: Citation Intent Classification\n",
        "\n",
        "This section shows how one can use the embedding for downstream tasks such as text classification. We'll use the [SciCite dataset](https://www.tensorflow.org/datasets/catalog/scicite) from TensorFlow Datasets to classify citation intents in academic papers. Given a sentence with a citation from an academic paper, classify whether the main intent of the citation is as background information, use of methods, or comparing results."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "-FB19HLfVp2V"
      },
      "outputs": [],
      "source": [
        "#@title Set up the dataset from TFDS\n",
        "\n",
        "class Dataset:\n",
        "  \"\"\"Build a dataset from a TFDS dataset.\"\"\"\n",
        "  def __init__(self, tfds_name, feature_name, label_name):\n",
        "    self.dataset_builder = tfds.builder(tfds_name)\n",
        "    self.dataset_builder.download_and_prepare()\n",
        "    self.feature_name = feature_name\n",
        "    self.label_name = label_name\n",
        "  \n",
        "  def get_data(self, for_eval):\n",
        "    splits = THE_DATASET.dataset_builder.info.splits\n",
        "    if tfds.Split.TEST in splits:\n",
        "      split = tfds.Split.TEST if for_eval else tfds.Split.TRAIN\n",
        "    else:\n",
        "      SPLIT_PERCENT = 80\n",
        "      split = \"train[{}%:]\".format(SPLIT_PERCENT) if for_eval else \"train[:{}%]\".format(SPLIT_PERCENT)\n",
        "    return self.dataset_builder.as_dataset(split=split)\n",
        "\n",
        "  def num_classes(self):\n",
        "    return self.dataset_builder.info.features[self.label_name].num_classes\n",
        "\n",
        "  def class_names(self):\n",
        "    return self.dataset_builder.info.features[self.label_name].names\n",
        "\n",
        "  def preprocess_fn(self, data):\n",
        "    return data[self.feature_name], data[self.label_name]\n",
        "\n",
        "  def example_fn(self, data):\n",
        "    feature, label = self.preprocess_fn(data)\n",
        "    return {'feature': feature, 'label': label}, label\n",
        "\n",
        "\n",
        "def get_example_data(dataset, num_examples, **data_kw):\n",
        "  \"\"\"Show example data\"\"\"\n",
        "  with tf.Session() as sess:\n",
        "    batched_ds = dataset.get_data(**data_kw).take(num_examples).map(dataset.preprocess_fn).batch(num_examples)\n",
        "    it = tf.data.make_one_shot_iterator(batched_ds).get_next()\n",
        "    data = sess.run(it)\n",
        "  return data\n",
        "\n",
        "\n",
        "TFDS_NAME = 'scicite' #@param {type: \"string\"}\n",
        "TEXT_FEATURE_NAME = 'string' #@param {type: \"string\"}\n",
        "LABEL_NAME = 'label' #@param {type: \"string\"}\n",
        "THE_DATASET = Dataset(TFDS_NAME, TEXT_FEATURE_NAME, LABEL_NAME)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "CVjyBD0ZPh4Z"
      },
      "outputs": [],
      "source": [
        "#@title Let's take a look at a few labeled examples from the training set\n",
        "NUM_EXAMPLES = 20  #@param {type:\"integer\"}\n",
        "data = get_example_data(THE_DATASET, NUM_EXAMPLES, for_eval=False)\n",
        "display_df(\n",
        "    pd.DataFrame({\n",
        "        TEXT_FEATURE_NAME: [ex.decode('utf8') for ex in data[0]],\n",
        "        LABEL_NAME: [THE_DATASET.class_names()[x] for x in data[1]]\n",
        "    }))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "65s9UpYJ_1ct"
      },
      "source": [
        "## Training a citaton intent classifier\n",
        "\n",
        "We'll train a classifier on the [SciCite dataset](https://www.tensorflow.org/datasets/catalog/scicite) using an Estimator. Let's set up the input_fns to read the dataset into the model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "OldapWmKSGsW"
      },
      "outputs": [],
      "source": [
        "def preprocessed_input_fn(for_eval):\n",
        "  data = THE_DATASET.get_data(for_eval=for_eval)\n",
        "  data = data.map(THE_DATASET.example_fn, num_parallel_calls=1)\n",
        "  return data\n",
        "\n",
        "\n",
        "def input_fn_train(params):\n",
        "  data = preprocessed_input_fn(for_eval=False)\n",
        "  data = data.repeat(None)\n",
        "  data = data.shuffle(1024)\n",
        "  data = data.batch(batch_size=params['batch_size'])\n",
        "  return data\n",
        "\n",
        "\n",
        "def input_fn_eval(params):\n",
        "  data = preprocessed_input_fn(for_eval=True)\n",
        "  data = data.repeat(1)\n",
        "  data = data.batch(batch_size=params['batch_size'])\n",
        "  return data\n",
        "\n",
        "\n",
        "def input_fn_predict(params):\n",
        "  data = preprocessed_input_fn(for_eval=True)\n",
        "  data = data.batch(batch_size=params['batch_size'])\n",
        "  return data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KcrmWUkVKg2u"
      },
      "source": [
        "Let's build a model which use the CORD-19 embeddings with a classification layer on top."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ff0uKqJCA9zh"
      },
      "outputs": [],
      "source": [
        "def model_fn(features, labels, mode, params):\n",
        "  # Embed the text\n",
        "  embed = hub.Module(params['module_name'], trainable=params['trainable_module'])\n",
        "  embeddings = embed(features['feature'])\n",
        "\n",
        "  # Add a linear layer on top\n",
        "  logits = tf.layers.dense(\n",
        "      embeddings, units=THE_DATASET.num_classes(), activation=None)\n",
        "  predictions = tf.argmax(input=logits, axis=1)\n",
        "\n",
        "  if mode == tf.estimator.ModeKeys.PREDICT:\n",
        "    return tf.estimator.EstimatorSpec(\n",
        "        mode=mode,\n",
        "        predictions={\n",
        "            'logits': logits,\n",
        "            'predictions': predictions,\n",
        "            'features': features['feature'],\n",
        "            'labels': features['label']\n",
        "        })\n",
        "  \n",
        "  # Set up a multi-class classification head\n",
        "  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
        "      labels=labels, logits=logits)\n",
        "  loss = tf.reduce_mean(loss)\n",
        "\n",
        "  if mode == tf.estimator.ModeKeys.TRAIN:\n",
        "    optimizer = tf.train.GradientDescentOptimizer(learning_rate=params['learning_rate'])\n",
        "    train_op = optimizer.minimize(loss, global_step=tf.train.get_or_create_global_step())\n",
        "    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)\n",
        "\n",
        "  elif mode == tf.estimator.ModeKeys.EVAL:\n",
        "    accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions)\n",
        "    precision = tf.metrics.precision(labels=labels, predictions=predictions)\n",
        "    recall = tf.metrics.recall(labels=labels, predictions=predictions)\n",
        "\n",
        "    return tf.estimator.EstimatorSpec(\n",
        "        mode=mode,\n",
        "        loss=loss,\n",
        "        eval_metric_ops={\n",
        "            'accuracy': accuracy,\n",
        "            'precision': precision,\n",
        "            'recall': recall,\n",
        "        })\n"
      ]
    },
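    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the loss above more concrete, the following is a minimal NumPy sketch of what a linear layer followed by a sparse softmax cross-entropy computes. All names prefixed with `toy_`, the batch size of 4, and the random stand-in embeddings are assumptions for illustration; this is not the Estimator's implementation and does not use the TF-Hub module."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "toy_rng = np.random.default_rng(0)\n",
        "\n",
        "# Hypothetical batch: 4 sentences, 128-d stand-in embeddings, 3 classes.\n",
        "toy_embeddings = toy_rng.normal(size=(4, 128))\n",
        "toy_class_ids = np.array([0, 2, 1, 0])\n",
        "\n",
        "# Linear layer without activation: logits = embeddings @ W + b.\n",
        "toy_weights = toy_rng.normal(scale=0.01, size=(128, 3))\n",
        "toy_bias = np.zeros(3)\n",
        "toy_logits = toy_embeddings @ toy_weights + toy_bias\n",
        "\n",
        "# Numerically stable log-softmax over the class axis.\n",
        "toy_shifted = toy_logits - toy_logits.max(axis=1, keepdims=True)\n",
        "toy_log_probs = toy_shifted - np.log(np.exp(toy_shifted).sum(axis=1, keepdims=True))\n",
        "\n",
        "# Sparse softmax cross-entropy: negative log-probability of the true class,\n",
        "# averaged over the batch.\n",
        "toy_loss = -toy_log_probs[np.arange(len(toy_class_ids)), toy_class_ids].mean()\n",
        "\n",
        "# The predicted class is the index of the largest logit.\n",
        "toy_predicted = toy_logits.argmax(axis=1)\n",
        "\n",
        "print(toy_loss, toy_predicted)"
      ]
    },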
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "yZUclu8xBYlj"
      },
      "outputs": [],
      "source": [
        "#@title Hyperparmeters { run: \"auto\" }\n",
        "\n",
        "EMBEDDING = 'https://tfhub.dev/tensorflow/cord-19/swivel-128d/1'  #@param {type: \"string\"}\n",
        "TRAINABLE_MODULE = False  #@param {type: \"boolean\"}\n",
        "STEPS =   8000#@param {type: \"integer\"}\n",
        "EVAL_EVERY = 200  #@param {type: \"integer\"}\n",
        "BATCH_SIZE = 10  #@param {type: \"integer\"}\n",
        "LEARNING_RATE = 0.01  #@param {type: \"number\"}\n",
        "\n",
        "params = {\n",
        "    'batch_size': BATCH_SIZE,\n",
        "    'learning_rate': LEARNING_RATE,\n",
        "    'module_name': EMBEDDING,\n",
        "    'trainable_module': TRAINABLE_MODULE\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "weZKWK-pLBll"
      },
      "source": [
        "## Train and evaluate the model\n",
        "\n",
        "Let's train and evaluate the model to see the performance on the SciCite task"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cO1FWkZW2WS9"
      },
      "outputs": [],
      "source": [
        "estimator = tf.estimator.Estimator(functools.partial(model_fn, params=params))\n",
        "metrics = []\n",
        "\n",
        "for step in range(0, STEPS, EVAL_EVERY):\n",
        "  estimator.train(input_fn=functools.partial(input_fn_train, params=params), steps=EVAL_EVERY)\n",
        "  step_metrics = estimator.evaluate(input_fn=functools.partial(input_fn_eval, params=params))\n",
        "  print('Global step {}: loss {:.3f}, accuracy {:.3f}'.format(step, step_metrics['loss'], step_metrics['accuracy']))\n",
        "  metrics.append(step_metrics)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RUNGAeyf1ygC"
      },
      "outputs": [],
      "source": [
        "global_steps = [x['global_step'] for x in metrics]\n",
        "fig, axes = plt.subplots(ncols=2, figsize=(20,8))\n",
        "\n",
        "for axes_index, metric_names in enumerate([['accuracy', 'precision', 'recall'],\n",
        "                                            ['loss']]):\n",
        "  for metric_name in metric_names:\n",
        "    axes[axes_index].plot(global_steps, [x[metric_name] for x in metrics], label=metric_name)\n",
        "  axes[axes_index].legend()\n",
        "  axes[axes_index].set_xlabel(\"Global Step\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1biWylvB6ayg"
      },
      "source": [
        "We can see that the loss quickly decreases while especially the accuracy rapidly increases. Let's plot some examples to check how the prediction relates to the true labels:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zK_NJXtoyG2o"
      },
      "outputs": [],
      "source": [
        "predictions = estimator.predict(functools.partial(input_fn_predict, params))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nlxFER_Oriam"
      },
      "outputs": [],
      "source": [
        "first_10_predictions = list(itertools.islice(predictions, 10))\n",
        "\n",
        "display_df(\n",
        "  pd.DataFrame({\n",
        "      TEXT_FEATURE_NAME: [pred['features'].decode('utf8') for pred in first_10_predictions],\n",
        "      LABEL_NAME: [THE_DATASET.class_names()[pred['labels']] for pred in first_10_predictions],\n",
        "      'prediction': [THE_DATASET.class_names()[pred['predictions']] for pred in first_10_predictions]\n",
        "  }))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OSGcrkE069_Q"
      },
      "source": [
        "We can see that for this random sample, the model predicts the correct label most of the times, indicating that it can embed scientific sentences pretty well."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oLE0kCfO5CIA"
      },
      "source": [
        "# What's next?\n",
        "\n",
        "Now that you've gotten to know a bit more about the CORD-19 Swivel embeddings from TF-Hub, we encourage you to participate in the CORD-19 Kaggle competition to contribute to gaining scientific insights from COVID-19 related academic texts.\n",
        "\n",
        "* Participate in the [CORD-19 Kaggle Challenge](https://www.kaggle.com/allen-institute-for-ai/CORD-19-research-challenge)\n",
        "* Learn more about the [COVID-19 Open Research Dataset (CORD-19)](https://pages.semanticscholar.org/coronavirus-research)\n",
        "* See documentation and more about the TF-Hub embeddings at https://tfhub.dev/tensorflow/cord-19/swivel-128d/1\n",
        "* Explore the CORD-19 embedding space with the [TensorFlow Embedding Projector](http://projector.tensorflow.org/?config=https://storage.googleapis.com/tfhub-examples/tensorflow/cord-19/swivel-128d/1/tensorboard/full_projector_config.json)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "5wFF5JFyD2Ki"
      ],
      "name": "Exploring the TF-Hub CORD-19 Swivel Embeddings",
      "private_outputs": true,
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
